--- ../../plenoctree/octree/optimization.py	2024-02-13 12:39:49.535388381 -0500
+++ ../google/plenoctree/octree/optimization.py	2024-02-13 10:46:24.360645261 -0500
@@ -54,6 +54,13 @@
 from octree.nerf import datasets
 from octree.nerf import utils
 
+import sys
+sys.path.append('..')
+from my_python_utils.common_utils import *
+
+from wavelets.wavelet_octree import WaveletN3Tree
+from wavelets_evaluator_3d.constant_3d import ConstantEvaluator3D
+
 FLAGS = flags.FLAGS
 
 utils.define_flags()
@@ -74,7 +81,7 @@
     'render interval')
 flags.DEFINE_integer(
     'val_interval',
-    2,
+    1,
     'validation interval')
 flags.DEFINE_integer(
     'num_epochs',
@@ -127,12 +134,28 @@
     "If set, continues training even if validation PSNR decreases",
 )
 
+
+# our flags: restart optimization from the loaded tree, keeping only its density
+flags.DEFINE_bool(
+    "reset_optimization",
+    False,
+    "If set to True, resets optimization by setting all data except density to 0"
+)
+
+# our flags
+flags.DEFINE_bool(
+    "wavelet_octree",
+    False,
+    "Optimize using wavelet based octree"
+)
+
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.autograd.set_detect_anomaly(True)
 
 
 def main(unused_argv):
-    utils.set_random_seed(20200823)
+    utils.set_random_seed(20200823) 
     utils.update_flags(FLAGS)
 
     def get_data(stage):
@@ -164,14 +187,39 @@
     os.makedirs(vis_dir, exist_ok=True)
 
     print('N3Tree load')
-    t = svox.N3Tree.load(FLAGS.input, map_location=device)
+    if FLAGS.wavelet_octree:
+        t = WaveletN3Tree.load(FLAGS.input, map_location=device)
+        # for now use the 3D Haar wavelet
+        wavelet_evaluator = ConstantEvaluator3D()
+        t.data_dim = wavelet_evaluator.get_data_size() + 1
+        alpha_data = t.data.detach()[...,-1:]
+        info = torch.zeros(*t.data.shape[:-1], t.data_dim)
+        info[...,-1:] = alpha_data
+        t.data = torch.nn.Parameter(info.to(device))
+        t.data.requires_grad = True
+        from svox.helpers import DataFormat
+        t.data_format = DataFormat.RGBA
+    else:
+        t = svox.N3Tree.load(FLAGS.input, map_location=device)    
+        if FLAGS.reset_optimization:
+            info = t.data.detach() # torch.zeros_like(t.data).to(device)
+            # set everything except density to 0
+            info[..., :-1] = 0
+            
+            t.data = torch.nn.Parameter(info).to(t.data.device)
+            t.data.requires_grad = True
+
     #  t.nan_to_num_()
 
     if 'llff' in FLAGS.config:
         ndc_config = svox.NDCConfig(width=W, height=H, focal=focal)
     else:
         ndc_config = None
-    r = svox.VolumeRenderer(t, step_size=FLAGS.renderer_step_size, ndc=ndc_config)
+    if FLAGS.wavelet_octree:
+        from wavelets_nerf.volume_render_octree_wavelet import VolumeRendererWavelet
+        r = VolumeRendererWavelet(wavelet_evaluator, t, step_size=FLAGS.renderer_step_size, ndc=ndc_config)
+    else:
+        r = svox.VolumeRenderer(t, step_size=FLAGS.renderer_step_size, ndc=ndc_config)
 
     if FLAGS.sgd:
         print('Using SGD, lr', FLAGS.lr)
@@ -187,12 +235,26 @@
     n_train_imgs = len(train_c2w)
     n_test_imgs = len(test_c2w)
 
+    N_to_evaluate = 5
     def run_test_step(i):
         print('Evaluating')
         with torch.no_grad():
             tpsnr = 0.0
+            js_to_evaluate = list(range(n_test_imgs))  # default: evaluate every test image
+
+            if N_to_evaluate < n_test_imgs:  # subsample without replacement so len() equals the number of images actually rendered
+                js_to_evaluate = np.random.choice(js_to_evaluate, N_to_evaluate, replace=False)
+                
             for j, (c2w, im_gt) in enumerate(zip(test_c2w, test_gt)):
-                im = r.render_persp(c2w, height=H, width=W, fx=focal, fast=False)
+                if not j in js_to_evaluate:
+                    continue
+                if FLAGS.wavelet_octree:
+                    rays = r.persp_rays(c2w, width=800, height=800, fx=1111.111, fy=None)
+                    im = r.forward(rays, cuda=False, wavelet_evaluator=wavelet_evaluator)
+                else:
+                    im = r.render_persp(c2w, height=H, width=W, fx=focal, fast=True)
+                im = im.reshape(im_gt.shape)
+
                 im = im.cpu().clamp_(0.0, 1.0)
 
                 mse = ((im - im_gt) ** 2).mean()
@@ -203,19 +265,25 @@
                     vis = torch.cat((im_gt, im), dim=1)
                     vis = (vis * 255).numpy().astype(np.uint8)
                     imageio.imwrite(f"{vis_dir}/{i:04}_{j:04}.png", vis)
-            tpsnr /= n_test_imgs
+            tpsnr /= len(js_to_evaluate)
             return tpsnr
 
     best_validation_psnr = run_test_step(0)
     print('** initial val psnr ', best_validation_psnr)
     best_t = None
-    for i in range(FLAGS.num_epochs):
-        print('epoch', i)
+    for epoch_i in range(FLAGS.num_epochs):
         tpsnr = 0.0
-        for j, (c2w, im_gt) in tqdm(enumerate(zip(train_c2w, train_gt)), total=n_train_imgs):
-            im = r.render_persp(c2w, height=H, width=W, fx=focal, cuda=True)
+        pbar = tqdm(enumerate(zip(train_c2w, train_gt)), total=n_train_imgs)
+
+        for j, (c2w, im_gt) in pbar:
+            if FLAGS.wavelet_octree:
+                rays = r.persp_rays(c2w, width=800, height=800, fx=1111.111, fy=None)
+                im = r.forward(rays, cuda=False, wavelet_evaluator=wavelet_evaluator)
+            else:
+                im = r.render_persp(c2w, height=H, width=W, fx=focal, cuda=True)
             im_gt_ten = im_gt.to(device=device)
             im = torch.clamp(im, 0.0, 1.0)
+            im = im.reshape(im_gt.shape)
             mse = ((im - im_gt_ten) ** 2).mean()
             im_gt_ten = None
 
@@ -227,11 +295,13 @@
             #  t.data.data -= eta * t.data.grad
             psnr = -10.0 * np.log(mse.detach().cpu()) / np.log(10.0)
             tpsnr += psnr.item()
+            imshow(im, 'optimization image')
+            pbar.set_description(f'epoch {epoch_i} cur train psnr={psnr:.2f}')
         tpsnr /= n_train_imgs
         print('** train_psnr', tpsnr)
 
-        if i % FLAGS.val_interval == FLAGS.val_interval - 1 or i == FLAGS.num_epochs - 1:
-            validation_psnr = run_test_step(i + 1)
+        if epoch_i % FLAGS.val_interval == FLAGS.val_interval - 1 or epoch_i == FLAGS.num_epochs - 1:
+            validation_psnr = run_test_step(epoch_i + 1)
             print('** val psnr ', validation_psnr, 'best', best_validation_psnr)
             if validation_psnr > best_validation_psnr:
                 best_validation_psnr = validation_psnr
